{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\nTrain a linear regression potential\n===================================\n\nIn this tutorial, we train a linear regression model on descriptors of the atomic\nenvironments computed with symmetry functions.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
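        "from kliff.descriptors import SymmetryFunction\nfrom kliff.dataset import Dataset\nfrom kliff.models import LinearRegression\nfrom kliff.calculators import CalculatorTorch\n\n\n# symmetry-function descriptor: cosine cutoff function with a Si-Si cutoff\n# distance of 5.0, the predefined \"set30\" hyperparameters, and normalization\ndescriptor = SymmetryFunction(\n    cut_name=\"cos\", cut_dists={\"Si-Si\": 5.0}, hyperparams=\"set30\", normalize=True\n)\n\n\n# linear regression model built on the descriptor\nmodel = LinearRegression(descriptor)\n\n# training set (the dataset directory is expected to exist locally)\ndataset_name = \"Si_training_set/varying_alat\"\ntset = Dataset()\ntset.read(dataset_name)\nconfigs = tset.get_configs()\nprint(\"Number of configurations:\", len(configs))\n\n# calculator: generates the descriptors (fingerprints) for all configurations;\n# reuse=True reuses previously generated fingerprints if they exist\ncalc = CalculatorTorch(model)\ncalc.create(configs, reuse=True)"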
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We could train a linear regression model by minimizing a loss function, as discussed in\nthe neural network tutorial (`tut_nn`). However, a linear regression model has an analytic\n(closed-form) solution, so we can instead fit it directly, without iterative optimization.\nThis is done by calling the ``fit()`` method of its calculator.\n"
      ]
    },
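    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the idea of an analytic solution concrete, here is a minimal NumPy sketch of a closed-form linear least-squares fit. It is only an illustration under stated assumptions, not KLIFF's implementation: the data are synthetic, and the names `X`, `A`, `theta`, `true_w`, `true_b` and the 30-component feature size are hypothetical. Instead of minimizing a loss iteratively, the weights and bias are obtained in a single step with `np.linalg.lstsq`. With noise-free data and a full-rank `A`, the recovered `w` and `b` match the values used to generate the data.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n# Hypothetical synthetic data: 50 configurations, each described by a\n# 30-component feature vector (stand-ins for descriptor-based features).\nrng = np.random.default_rng(0)\nX = rng.normal(size=(50, 30))\ntrue_w = rng.normal(size=30)\ntrue_b = 0.5\ny = X @ true_w + true_b  # noise-free targets, for illustration only\n\n# Append a column of ones so the bias is solved together with the weights.\nA = np.hstack([X, np.ones((X.shape[0], 1))])\n\n# Closed-form least-squares solution of A @ theta ~= y. It matches the normal\n# equations theta = (A^T A)^{-1} A^T y when A^T A is invertible, but lstsq\n# avoids forming the inverse explicitly.\ntheta, *_ = np.linalg.lstsq(A, y, rcond=None)\nw, b = theta[:-1], theta[-1]\n\nprint(np.allclose(w, true_w), np.isclose(b, true_b))"
      ]
    },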
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# fit the model analytically (no loss function or optimizer needed)\ncalc.fit()\n\n# save the trained model to disk\nmodel.save(\"linear_model.pkl\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}